// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===- Transforms.h - Transformations common to all backends --------------===//
//
// Defines transformations that are common to all backends.
//
//===----------------------------------------------------------------------===//
#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMS_H_
#define IREE_COMPILER_CODEGEN_COMMON_TRANSFORMS_H_

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

#include <array>
#include <functional>
#include <utility>

namespace mlir {
namespace iree_compiler {

/// Callback used by `defineWorkgroupCountRegion` to compute the workgroup
/// count along (x, y, z) from the workload along (x, y, z). The callback is
/// given an `OpBuilder` and a `Location` to use for any ops it creates.
using WorkgroupCountRegionBuilder = std::function<std::array<Value, 3>(
    OpBuilder &b, Location loc, std::array<Value, 3> workload)>;

/// Specifies the number of workgroups to use for a particular entry point
/// function by updating the `workgroup_count` region in the
/// `hal.executable.entry_point` op for `funcOp`. The workgroup count (x, y, z)
/// is computed by `regionBuilder` from the workload along (x, y, z).
// NOTE(review): the conditions under which this returns failure are defined
// by the implementation (not visible here) — confirm before relying on them.
LogicalResult defineWorkgroupCountRegion(
    OpBuilder &builder, FuncOp funcOp,
    WorkgroupCountRegionBuilder regionBuilder);

/// When Linalg on tensors is used to create dispatch regions, the first level
/// of tiling (tile, fuse and distribute) happens during dispatch region
/// formation. At that point the workload per workgroup is set to the dynamic
/// value represented by `flow.dispatch.workgroup.size`, which is later lowered
/// to `hal.dispatch.workgroup.size`. This method materializes the static
/// workload per workgroup, `workloadPerWorkgroup`, as determined for the
/// target architecture. Note that the value of `hal.dispatch.workgroup.size`
/// is different after this function is called and then represents the actual
/// value used at runtime.
LogicalResult materializeStaticLaunchInformation(
    FuncOp funcOp, ArrayRef<int64_t> workloadPerWorkgroup);

/// Returns a fused vector::ContractionOp that replaces a pattern such as:
///
/// ```mlir
///    %c0 = vector.constant 0: ...
///    %c = vector.contract %a, %b, %c0: ...
///    %e = add %c, %d: ...
/// ```
///
/// with:
///
/// ```mlir
///    %e = vector.contract %a, %b, %d: ...
/// ```
///
/// That is, a contraction with a zero accumulator followed by an add is folded
/// into a single contraction that uses the other add operand as accumulator.
/// NOTE(review): `op` is presumably the root `add` of the pattern — confirm
/// against the implementation.
///
/// Returns null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
vector::ContractionOp canonicalizeContractionAdd(Operation *op);

/// Inserts patterns that fold `affine.min` ops by matching the pattern
/// generated by tile and distribute. An `affine.min` op is folded when it
/// matches the following form:
/// ```
/// scf.for %iv = %lb to %ub step %step
///   %0 = affine.min affine_map<(d0, d1) -> (N, d0 - d1)>(%ub, %iv)
/// ```
/// with N a compile-time constant. Such an operation can be replaced by
/// `%cN = constant N : index` if we can prove that %lb, %step and %ub are
/// divisible by N.
void populateAffineMinSCFCanonicalizationPattern(RewritePatternSet &patterns);

/// Callback that returns a pair of affine expressions for `value`, presumably
/// its (min, max) bounds — TODO confirm. An empty Optional presumably means no
/// bounds are known. The `dims` and `symbols` out-parameters presumably collect
/// the values bound to the dimensions/symbols of the returned expressions.
using GetMinMaxExprFn =
    std::function<Optional<std::pair<AffineExpr, AffineExpr>>(
        Value value, SmallVectorImpl<Value> &dims,
        SmallVectorImpl<Value> &symbols)>;

/// Inserts a pattern that removes single-iteration loops. The pattern detects
/// single-iteration loops based on the range returned by the callback
/// `getMinMaxFn` for some known values.
void populateRemoveSingleIterationLoopPattern(RewritePatternSet &patterns,
                                              GetMinMaxExprFn getMinMaxFn);

/// Inserts a pattern that folds chains of `affine.min` operations.
// TODO: It is not clear what this pattern does; it should be deprecated.
void populateAffineMinCanonicalizationPattern(RewritePatternSet &patterns);

}  // namespace iree_compiler
}  // namespace mlir

#endif  // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMS_H_
